/*
	Rijndael Block Cipher - rijndael.c

	Written by Mike Scott 21st April 1999
	mike@compapp.dcu.ie

	Permission for free direct or derivative use is granted subject
	to compliance with any conditions that the originators of the
	algorithm place on its exploitation.
*/
#include <string.h>
#include <gctypes.h>

#include "rijndael.h"

/* Rotate an 8-bit value one bit to the left. The expression is evaluated in
 * int, so the result must be stored back into a u8 to discard the carried-out
 * bit 8 (ByteSub does this by assigning to a u8). Evaluates x twice. */
#define ROTL(x) (((x)>>7)|((x)<<1))

/* Rotate a 32-bit word left by 1, 2 or 3 bytes. Only correct for u32
 * operands (cast narrower values to u32 first, as encrypt/decrypt do).
 * Each macro evaluates its argument twice. */
#define ROTL8(x) (((x)<<8)|((x)>>24))
#define ROTL16(x) (((x)<<16)|((x)>>16))
#define ROTL24(x) (((x)<<24)|((x)>>8))

/* Fixed Data */
static u8 InCo[4]= { 0xB, 0xD, 0x9, 0xE };  /* InvMixColumns coefficients 0B,0D,09,0E */

/* Tables filled in by gentables() */
static u8 fbsub[256];	/* forward S-box */
static u8 rbsub[256];	/* inverse S-box */
static u8 ptab[256], ltab[256];	/* powers of 3 in GF(2^8), and their logarithms */
static u32 ftable[256];	/* S-box output times MixColumns column (2,1,1,3), packed by pack() */
static u32 rtable[256];	/* inverse S-box output times InCo[], packed by pack() */
static u32 rco[30];	/* key-schedule round constants: 1, then repeated xtime() */

/* Parameter-dependent data, written by gkey(). There is a single key
 * schedule for the whole file: each aes_set_key() call replaces it. */
static int Nk, Nb, Nr;	/* key length (words), block length (words), rounds */
static u8 fi[24], ri[24];	/* forward/reverse ShiftRows source columns, 3 per column */
static u32 fkey[120];	/* encryption round keys: Nb*(Nr+1) <= 8*15 words */
static u32 rkey[120];	/* decryption round keys, same size */

/* Build a 32-bit word from four bytes; b[0] becomes the least significant
 * byte and b[3] the most significant (little-endian packing). */
static u32 pack(u8 *b) {
	u32 word = 0;
	int n;

	/* Shift in from the most significant byte down to b[0]. */
	for (n = 3; n >= 0; n--)
		word = (word << 8) | b[n];
	return word;
}

/* Split a 32-bit word into four bytes, least significant byte first
 * (the inverse of pack()). Writes b[0..3]. */
static void unpack(u32 a, u8 *b) {
	int n;

	for (n = 0; n < 4; n++) {
		b[n] = (u8)(a & 0xFFu);
		a >>= 8;
	}
}

/* Multiply a by x (i.e. by 2) in GF(2^8): shift left one bit and, if the
 * top bit was set, reduce by XORing 0x1B (the AES polynomial 0x11B with the
 * x^8 term dropped by the u8 truncation). */
static u8 xtime(u8 a) {
	u8 reduce = (a & 0x80) ? 0x1B : 0x00;

	return (u8)((a << 1) ^ reduce);
}

/* GF(2^8) multiplication through the log/antilog tables built by
 * gentables(): x*y = ptab[(ltab[x] + ltab[y]) mod 255]. */
static u8 bmul(u8 x, u8 y) {
	if (!x || !y)
		return 0;	/* zero has no logarithm */
	return ptab[(ltab[x] + ltab[y]) % 255];
}

/* Apply the forward S-box (fbsub) independently to each byte of a word. */
static u32 SubByte(u32 a) {
	u32 out = 0;
	int shift;

	for (shift = 0; shift < 32; shift += 8)
		out |= (u32)fbsub[(a >> shift) & 0xFF] << shift;
	return out;
}

/* GF(2^8) dot product of the two 4-byte vectors packed in x and y:
 * XOR over the byte lanes of bmul(x_lane, y_lane). */
static u8 product(u32 x, u32 y) {
	u8 sum = 0;
	int shift;

	for (shift = 0; shift < 32; shift += 8)
		sum ^= bmul((u8)(x >> shift), (u8)(y >> shift));
	return sum;
}

/* Multiply one state column x by the circulant matrix whose rows are
 * rotations of InCo[] (InvMixColumns). Output byte 3 uses the unrotated
 * coefficients; each lower byte uses them rotated by one more byte. */
static u32 InvMixCol(u32 x) {
	u8 col[4];
	u32 coeffs = pack(InCo);
	int row;

	for (row = 3; row >= 0; row--) {
		col[row] = product(coeffs, x);
		coeffs = ROTL24(coeffs);
	}
	return pack(col);
}

/*
 * Compute one forward S-box entry. It takes the multiplicative inverse of
 * x in GF(2^8) via the log/antilog tables, XORs it with its four successive
 * one-bit left rotations, then XORs 0x63 (the affine transform).
 * Not valid for x == 0; gentables() sets that entry to 0x63 itself.
 */
static u8 ByteSub(u8 x) {
	u8 inverse = ptab[255 - ltab[x]];
	u8 acc = inverse;
	u8 rot = inverse;
	int n;

	for (n = 0; n < 4; n++) {
		rot = ROTL(rot);	/* assignment to u8 drops the carried-out bit */
		acc ^= rot;
	}
	return acc ^ 0x63;
}

/* generate tables */
static void gentables() {
	int i;
	u8 y, b[4];

	/* use 3 as primitive root to generate power and log tables */

	ltab[0] = 0;
	ptab[0] = 1; ltab[1] = 0;
	ptab[1] = 3; ltab[3] = 1;
	for (i = 2; i < 256; i++) {
		ptab[i] = ptab[i-1] ^ xtime(ptab[i - 1]);
		ltab[ptab[i]] = i;
	}

	/* affine transformation:- each bit is xored with itself shifted one bit */

	fbsub[0] = 0x63;
	rbsub[0x63] = 0;
	for (i = 1; i < 256; i++) {
		y = ByteSub((u8)i);
		fbsub[i] = y; rbsub[y] = i;
	}

	for (i = 0, y = 1; i < 30; i++) {
		rco[i] = y;
		y = xtime(y);
	}

	/* calculate forward and reverse tables */
	for (i = 0; i < 256; i++) {
		y = fbsub[i];
		b[3] = y ^ xtime(y); b[2] = y;
		b[1] = y;			b[0] = xtime(y);
		ftable[i] = pack(b);

		y = rbsub[i];
		b[3] = bmul(InCo[0], y); b[2] = bmul(InCo[1], y);
		b[1] = bmul(InCo[2], y); b[0] = bmul(InCo[3], y);
		rtable[i] = pack(b);
	}
}

/*
 * Key scheduler. Expands a 32*nk-bit key for a 32*nb-bit block into:
 *   fkey[0 .. N-1]  encryption round keys, N = Nb * (Nr + 1) words
 *   rkey[0 .. N-1]  decryption round keys: fkey's rounds in reverse order,
 *                   with InvMixCol applied to every round except the
 *                   first and last
 * Also sets the globals Nb, Nk, Nr and the ShiftRows index tables fi/ri
 * that encrypt() and decrypt() read. gentables() must run first, because
 * SubByte() and InvMixCol() use its tables.
 *
 * nb, nk: block and key length in 32-bit words (currently 4, 6 or 8).
 * key:    4*nk bytes, packed into words least-significant byte first by pack().
 */
static void gkey(int nb, int nk, u8 *key) {
	int i, j, k, m, N;
	int C1, C2, C3;
	u32 CipherKey[8];	/* nk is at most 8 words */

	Nb = nb; Nk = nk;

	/* Nr is number of rounds: 6 + max(Nb, Nk) */
	if (Nb >= Nk) Nr = 6 + Nb;
	else Nr = 6 + Nk;

	/* ShiftRows offsets for rows 1..3; larger for 256-bit blocks */
	C1 = 1;
	if (Nb < 8) { C2 = 2; C3 = 3; }
	else { C2 = 3; C3 = 4; }

	/* Pre-calculate, for each column j, which source column supplies rows
	 * 1..3 after ShiftRows (fi) and after InvShiftRows (ri). */
	for (m = j = 0; j < nb; j++, m += 3) {
		fi[m] = (j + C1) % nb;
		fi[m +1 ] = (j + C2) % nb;
		fi[m +2 ] = (j + C3) % nb;
		ri[m] = (nb + j - C1) % nb;
		ri[m + 1] = (nb +j - C2) % nb;
		ri[m + 2] = (nb +j - C3) % nb;
	}

	N = Nb * (Nr + 1);	/* total round-key words, at most 120 */

	for (i = j = 0; i < Nk; i++, j += 4) {
		CipherKey[i] = pack(key + j);
	}
	for (i = 0; i < Nk; i++) fkey[i] = CipherKey[i];
	/* Each pass produces up to Nk words. The first word of each group gets
	 * the rotate (ROTL24 on a little-endian-packed word), the S-box and the
	 * round constant. The rest are XOR chains. For Nk > 6 the word at
	 * offset 4 also passes through the S-box. The (i + j) < N tests stop at
	 * the end of the schedule. */
	for ( j =Nk, k = 0; j < N; j+= Nk, k++) {
		fkey[j] = fkey[j - Nk] ^ SubByte(ROTL24(fkey[j - 1])) ^ rco[k];
		if (Nk <= 6) {
			for (i = 1; i < Nk && (i + j) < N; i++)
				fkey[i + j] = fkey[i + j - Nk] ^ fkey[i + j - 1];
		} else {
			for (i = 1; i < 4 && (i + j) < N; i++)
				fkey[i + j] = fkey[i + j - Nk] ^ fkey[i + j - 1];
			if ((j + 4) < N) fkey[j + 4] = fkey[j + 4 - Nk] ^ SubByte(fkey[j + 3]);
			for (i = 5; i < Nk && (i + j) < N; i++)
				fkey[i + j] = fkey[i + j - Nk] ^ fkey[i + j - 1];
		}
	}

	/* Now the expanded decrypt key, in reverse round order. The first and
	 * last round keys are copied as-is; the middle ones go through
	 * InvMixCol so decrypt() can use the same round structure as encrypt(). */

	for (j = 0; j < Nb; j++) rkey[j + N - Nb] = fkey[j];
	for (i = Nb; i < N - Nb; i += Nb) {
		k = N - Nb - i;
		for (j = 0; j < Nb; j++) rkey[k + j] = InvMixCol(fkey[i + j]);
	}
	for (j = N - Nb; j < N; j++) rkey[j - N + Nb] = fkey[j];
}


/*
 * Encrypt one block of 4*Nb bytes in place using the round keys in fkey[].
 * gentables() and gkey() must have been called first.
 *
 * Each inner round does SubBytes, ShiftRows and MixColumns in one lookup
 * per byte from ftable[]. The other three byte positions reuse the same
 * table and rotate the result on the fly. Four pre-rotated tables would
 * avoid the ROTL8/ROTL16/ROTL24 work, at a cost of 3 KiB more data.
 */
static void encrypt(u8 *buff) {
	int i, j, k, m;
	u32 a[8], b[8], *x, *y, *t;

	/* Load the state and add the first round key. */
	for (i = j = 0; i < Nb; i++, j += 4) {
		a[i] = pack(buff + j);
		a[i] ^= fkey[i];
	}
	k = Nb;
	x = a; y = b;

	/* State alternates between a and b */
	/* Nr is number of rounds. May be odd. */
	for (i = 1; i < Nr; i++) {

		/* if Nb is fixed - unroll this next loop and hard-code in the values of fi[] */

		/* deal with each 32-bit element of the State */
		for (m = j = 0; j < Nb; j++, m += 3) {
			/* This is the time-critical bit */
			y[j] = fkey[k++] ^ ftable[(u8)x[j]] ^
				ROTL8(ftable[(u8)(x[fi[m]] >> 8)]) ^
				ROTL16(ftable[(u8)(x[fi[m + 1]] >> 16)]) ^
				ROTL24(ftable[x[fi[m + 2]] >> 24]);
		}
		t = x; x = y; y = t;	/* swap pointers */
	}

	/* Last round: no MixColumns, so look up the plain S-box fbsub[]. */
	for (m = j = 0; j < Nb; j++, m += 3) {
		y[j] = fkey[k++] ^ (u32)fbsub[(u8)x[j]] ^
			ROTL8((u32)fbsub[(u8)(x[fi[m]] >> 8)]) ^
			ROTL16((u32)fbsub[(u8)(x[fi[m + 1]] >> 16)]) ^
			ROTL24((u32)fbsub[x[fi[m + 2]] >> 24]);
	}
	for (i = j = 0; i < Nb; i++, j += 4)
		unpack(y[i], buff + j);

	/* Scrub key-dependent round state from the stack. Plain stores to a/b
	 * here are dead (nothing reads them again), so an optimizing compiler
	 * may delete them. Writing through a volatile-qualified pointer keeps
	 * the stores in practice. */
	{
		volatile u32 *va = a, *vb = b;
		for (i = 0; i < (int)(sizeof(a) / sizeof(a[0])); i++) {
			va[i] = 0;
			vb[i] = 0;
		}
	}
}

/*
 * Decrypt one block of 4*Nb bytes in place using the round keys in rkey[].
 * gentables() and gkey() must have been called first.
 *
 * This mirrors encrypt(). rtable[] folds InvSubBytes and InvMixColumns into
 * one lookup, ri[] supplies the InvShiftRows source columns, and the last
 * round uses the plain inverse S-box rbsub[].
 */
static void decrypt(u8 *buff) {
	int i, j, k, m;
	u32 a[8], b[8], *x, *y, *t;

	/* Load the state and add the first decryption round key. */
	for (i = j = 0; i < Nb; i++, j += 4) {
		a[i] = pack(buff + j);
		a[i] ^= rkey[i];
	}
	k = Nb;
	x = a; y = b;

	/* State alternates between a and b */
	/* Nr is number of rounds. May be odd. */
	for (i = 1; i < Nr; i++) {

		/* if Nb is fixed - unroll this next loop and hard-code in the values of ri[] */

		for (m = j = 0; j < Nb; j++, m += 3) {
			/* This is the time-critical bit */
			y[j] = rkey[k++] ^ rtable[(u8)x[j]] ^
				ROTL8(rtable[(u8)(x[ri[m]] >> 8)]) ^
				ROTL16(rtable[(u8)(x[ri[m + 1]] >> 16)]) ^
				ROTL24(rtable[x[ri[m + 2]] >> 24]);
		}
		t = x; x = y; y = t;	/* swap pointers */
	}

	/* Last round: no InvMixColumns, so look up the plain inverse S-box. */
	for (m = j = 0; j < Nb; j++, m += 3) {
		y[j] = rkey[k++] ^ (u32)rbsub[(u8)x[j]] ^
			ROTL8((u32)rbsub[(u8)(x[ri[m]] >> 8)]) ^
			ROTL16((u32)rbsub[(u8)(x[ri[m + 1]] >> 16)]) ^
			ROTL24((u32)rbsub[x[ri[m + 2]] >> 24]);
	}
	for (i = j = 0; i < Nb; i++, j += 4)
		unpack(y[i], buff + j);

	/* Scrub key-dependent round state from the stack. Plain stores to a/b
	 * here are dead (nothing reads them again), so an optimizing compiler
	 * may delete them. Writing through a volatile-qualified pointer keeps
	 * the stores in practice. */
	{
		volatile u32 *va = a, *vb = b;
		for (i = 0; i < (int)(sizeof(a) / sizeof(a[0])); i++) {
			va[i] = 0;
			vb[i] = 0;
		}
	}
}

/*
 * Install a 16-byte key for AES-128 (128-bit block, 128-bit key).
 * This rebuilds the key-independent tables, then expands the key into the
 * file-wide round-key schedule used by aes_encrypt()/aes_decrypt(). A later
 * call replaces the previous key.
 */
void aes_set_key(u8 *key) {
	const int block_words = 4;	/* 128-bit block */
	const int key_words = 4;	/* 128-bit key */

	gentables();
	gkey(block_words, key_words, key);
}

/*
 * CBC mode decryption with the key installed by aes_set_key().
 *
 * iv:     16-byte initialisation vector. It is read but NOT updated (unlike
 *         aes_encrypt(), which advances it).
 * inbuf:  len bytes of ciphertext.
 * outbuf: receives len bytes of plaintext. It may be the same buffer as
 *         inbuf (in-place decryption): each ciphertext block is saved before
 *         its output is written, so the chaining value is never clobbered.
 * len:    byte count. A trailing partial block (len % 16 bytes) is
 *         zero-padded, decrypted, and only its first len % 16 bytes are
 *         written.
 */
void aes_decrypt(u8 *iv, u8 *inbuf, u8 *outbuf, unsigned long long len) {
	u8 block[16];	/* ciphertext in, block-decrypted value out */
	u8 chain[16];	/* previous ciphertext block; the IV for block 0 */
	u8 saved[16];	/* copy of the current ciphertext block */
	/* 64-bit counter: an unsigned int could never exceed len / 16 for large len */
	unsigned long long blockno, nblocks = len / sizeof(block);
	unsigned int fraction, i;

	if (len == 0) return;
	memcpy(chain, iv, sizeof(chain));

	for (blockno = 0; blockno <= nblocks; blockno++) {
		u8 *in = inbuf + blockno * sizeof(block);
		u8 *out = outbuf + blockno * sizeof(block);

		if (blockno == nblocks) { /* last (partial) block */
			fraction = (unsigned int)(len % sizeof(block));
			if (fraction == 0) break;
			memset(block, 0, sizeof(block));
		} else fraction = sizeof(block);

		memcpy(block, in, fraction);
		/* Keep the ciphertext: with in-place use, `in` is overwritten below. */
		memcpy(saved, block, sizeof(saved));
		decrypt(block);

		for (i = 0; i < fraction; i++) out[i] = chain[i] ^ block[i];
		memcpy(chain, saved, sizeof(chain));
	}
}

/*
 * CBC mode encryption with the key installed by aes_set_key().
 *
 * iv:     16-byte initialisation vector. On return it holds the last
 *         ciphertext block, so consecutive calls continue one CBC stream.
 * inbuf:  len bytes of plaintext. It may be the same buffer as outbuf.
 * outbuf: receives the ciphertext, one full 16-byte block per input block.
 * len:    byte count.
 *
 * NOTE(review): a trailing partial block (len % 16 != 0) is handled in a
 * non-standard way, kept as-is for compatibility:
 *   - its padding bytes are zero and are NOT XORed with the IV;
 *   - a full 16 bytes are still written, i.e. up to 15 bytes past
 *     outbuf + len, so outbuf must be sized to a multiple of 16;
 *   - aes_decrypt() only reads len % 16 bytes of that block and cannot
 *     recover it.
 */
void aes_encrypt(u8 *iv, u8 *inbuf, u8 *outbuf, unsigned long long len) {
	u8 block[16];
	/* 64-bit counter: an unsigned int could never exceed len / 16 for large len */
	unsigned long long blockno, nblocks = len / sizeof(block);
	unsigned int fraction, i;

	for (blockno = 0; blockno <= nblocks; blockno++) {
		if (blockno == nblocks) { /* last (partial) block */
			fraction = (unsigned int)(len % sizeof(block));
			if (fraction == 0) break;
			memset(block, 0, sizeof(block));
		} else fraction = sizeof(block);

		/* Plaintext XOR chaining value (every byte that carries input). */
		for (i = 0; i < fraction; i++)
			block[i] = inbuf[blockno * sizeof(block) + i] ^ iv[i];

		encrypt(block);
		memcpy(iv, block, sizeof(block));
		memcpy(outbuf + blockno * sizeof(block), block, sizeof(block));
	}
}
